Skip to content

Vectorize discrimination_score; scale DE overlap/counts/AUC metrics to high perturbation counts#238

Open
FarzanT wants to merge 5 commits into
ArcInstitute:mainfrom
FarzanT:perf/vectorize-discrimination-score
Open

Vectorize discrimination_score; scale DE overlap/counts/AUC metrics to high perturbation counts#238
FarzanT wants to merge 5 commits into
ArcInstitute:mainfrom
FarzanT:perf/vectorize-discrimination-score

Conversation

@FarzanT
Copy link
Copy Markdown

@FarzanT FarzanT commented May 27, 2026

This PR contains three independent, bit-exact performance fixes to the DE /
anndata metric pipeline. All leave outputs numerically identical; none requires
re-baselining metrics. They were found in sequence: each fix cleared the
reigning bottleneck on a ~18k-perturbation screen and exposed the next.

  1. Vectorize discrimination_score (anndata-pair metric) — closes discrimination_score recomputes an n_pert x n_pert distance matrix one row at a time #237.
  2. Memoize the DE overlap rank-matrix pivot and set-ify column membership in
    compute_overlap / get_top_genes (DE metric).
  3. Replace the per-perturbation full-table scans in DENsigCounts and
    compute_generic_auc (pr/roc) with a single grouped/partitioned slice
    (DE metrics).

Fix 1 — Vectorize discrimination_score (closes #237)

What

discrimination_score looped over n_pert perturbations, calling
pairwise_distances once per perturbation to compute a single row of an
n_pert x n_pert distance matrix. This computes the full matrix in one call
and ranks each perturbation by locating its column's position in the per-row
sorted order.

The target-gene-exclusion path (default for expression data) drops a different
feature column per perturbation, so a single unmasked call can't reproduce it.
The full matrix is computed once and corrected per row with an exact,
vectorized rank-1 update that removes the target gene's contribution:

  • l1: subtract |pred_g - real_g|
  • l2: sqrt(d^2 - (pred_g - real_g)^2)
  • cosine: drop the column from the dot product and both norms (masked
    squared norms are clipped at 0 so a target-gene-dominated effect can't round
    negative into a NaN)

Metrics without a closed-form correction fall back to exact per-row masked
distances, and duplicate gene names matching a perturbation are handled by an
exact per-row safety net.

Ranking uses np.where(order == arange[:, None]) to find each row's matching
column rather than a second argsort over the full matrix — identical result,
but it avoids an extra O(n_pert^2) int64 array and a second sort (thanks to
the review suggestion). At n_pert=10000 the ranking step alone drops from
~9.8s / 2.4 GB to ~5.7s / 1.7 GB, a large share of the l2/cosine runtime.

Parity (no behavior change)

Output is numerically identical to the original loop. Across 216 synthetic
configurations (metric x exclude_target_gene x embed_key x seed x
targeting-fraction) the normalized ranks match bit-for-bit (worst
|delta-rank| = 0), and the equivalence is covered by the unit tests.

Benchmark

Apple M2 Pro, Python 3.12.12, numpy 2.4.6 / scipy 1.17.1 / scikit-learn 1.8.0,
n_genes=2000 (old = original per-perturbation loop, new = vectorized):

n_pert metric old (s) new (s) speedup
100 l1 0.053 0.013 4.1x
100 l2 0.044 0.005 9.4x
100 cosine 0.081 0.005 15.6x
1000 l1 3.703 1.028 3.6x
1000 l2 2.460 0.119 20.7x
1000 cosine 5.285 0.135 39.3x
2000 l1 22.015 4.143 5.3x
2000 l2 17.598 0.433 40.7x
2000 cosine 27.730 0.455 61.0x
10000 l1 978.5 114.2 8.6x
10000 l2 868.4 12.3 70.5x
10000 cosine 1164.5 14.7 79.2x

(n_pert=10000 old times are from the same reference loop on the same machine;
the loop is unchanged by this PR.)

l2/cosine use the BLAS dot-product trick, so one matrix multiply replaces
n_pert dispatched calls and the speedup grows with n_pert; l1 (manhattan)
is a non-BLAS kernel, so its gain is the removed per-call dispatch overhead.
Trade-off: the full matrix is O(n_pert^2) memory vs O(n_pert) per iteration
before — comfortable for typical screen sizes.

Reproduce: python benchmarks/bench_discrimination_score.py


Fix 2 — DE overlap metric: memoize the rank-matrix pivot + O(1) column membership

cell_eval/_types/_de.pyDEComparison.compute_overlap / DEResults.get_top_genes.

Motivation

An OmniPert report_only validate-only run over a HuangChu context with
~18,000 perturbations (16 CPU, 24h wall) ran for 8.8 hours and never
reached results.csv
. Faulthandler sampling of the live process (524 periodic
stack dumps) put 513/524 ≈ 98% of all samples inside a single metric:

de_overlap_metric -> DEComparison.compute_overlap -> DEResults.get_top_genes
  hottest leaf: polars frame.py ... in `columns`

The DE compute itself (pdex) was already fast (~135s); this is purely the
metric pass, and it ate the entire budget before the rest of the DE metrics
and the anndata-pair metrics even started.

Two compounding inefficiencies

1. Redundant rank-matrix pivots (the big one). compute_overlap calls
self.real.get_top_genes(sort_by, fdr_threshold) and self.pred.get_top_genes(...)
at the top of every invocation. get_top_genes builds a polars .pivot() with
one column per perturbation (~18k-wide). The de/full profile registers 10
variants
of this metric — metrics/_impl.py does
for metric in ["overlap","precision"]: for n in [None,50,100,200,500] with
kwargs={"k": n, "metric": metric} and the default sort_by / fdr_threshold.
So the identical pair of 18k-column pivots is rebuilt 10 times; k only
affects the per-pert truncation genes[:k_eff] inside the loop, never the
matrix.

Fix: memoize get_top_genes keyed by (sort_by, fdr_threshold) on the
per-side DEResults (a dataclass field excluded from init/repr/eq).
DESortBy is an enum (hashable), so the key is safe.

2. O(perts²) column membership. The per-pert loop did
if pert not in real_sig_rank_matrix.columns or pert not in pred_sig_rank_matrix.columns:.
polars rebuilds a fresh ~18k-element list on every .columns access, twice
per pert, over ~18k perts (~3×10⁸ list rebuilds — this is the columns leaf the
sampler kept hitting).

Fix: precompute real_cols = set(real_sig_rank_matrix.columns) and
pred_cols = set(...) once before the loop and test against the sets. (Column
selection matrix[pert] is unchanged.)

Impact

The overlap pass drops from 10 redundant wide pivots + an O(perts²) membership
loop
to 2 pivots + an O(perts) loop. What remains irreducible is one pivot
per side plus the per-pert intersect; everything the fix removes scales with
perturbation count.

Measured with benchmarks/bench_de_overlap.py (Apple M2 Pro, polars 1.x; one
full 10-variant pass over synthetic real/pred DE tables, all outputs verified
bit-identical between old and new):

n_genes=2000, n_sig=100 significant genes/pert:

n_pert old (s) new (s) speedup
1000 2.34 1.05 2.2x
2000 6.79 2.18 3.1x
4000 22.32 4.54 4.9x
8000 80.84 9.78 8.3x

The old path grows ~quadratically and the new path ~linearly, so the speedup
climbs with n_pert — the axis that matters for the ~18k-perturbation screens
that motivated this. With more significant genes per pert the absolute cost
rises (n_sig=500: 8000 perts = 166.8s old vs 34.0s new), since the irreducible
single-pivot + intersect work grows while the removed redundancy stays fixed.

This is the metric that consumed an entire 8.8h validate-only budget at ~18k
perts (98% of samples in this exact call stack — see issue below); the measured
scaling above is the mechanism behind that, reproduced at tractable sizes.

Reproduce: python benchmarks/bench_de_overlap.py --n-pert 1000 2000 4000 8000

Parity (no behavior change)

Both changes are pure performance; outputs are bit-identical. New
tests/test_de_overlap_equivalence.py:

  • compute_overlap matches a from-scratch reference across k ∈ {None,1,2,50,500}
    and metric ∈ {overlap, precision} (exact == on the result dicts).
  • the cache collapses the 10 registered variants to one pivot per side
    (len(_top_genes_cache) == 1 after all variants run).
  • distinct (sort_by, fdr_threshold) keys stay separate (3 keys → 3 entries).
  • the no-significant-genes early-return path still returns all-zeros.

The existing end-to-end tests/test_eval.py (which runs the de/full profile
through the registry, exercising all 10 overlap variants) is unchanged and
still passes.


Fix 3 — DE per-perturbation full-table scans (DENsigCounts, pr/roc AUC)

cell_eval/metrics/_de.pyDENsigCounts.__call__ / compute_generic_auc.

Motivation

With Fix 2 in place, the same ~18k-perturbation OmniPert report run cleared the
overlap pass (the phase that previously stalled 8.6h now completes) and exposed
the next instance of the same per-perturbation-scan pathology downstream.
Faulthandler sampling of the live process (job over a 371k-row DE table) now sat
continuously in:

DENsigCounts.__call__ -> DEResults.get_significant_genes -> polars filter().collect()

Two metrics share the shape — a for pert in iter_perturbations() loop whose
body does a full-table .filter(target == pert):

  1. DENsigCounts called get_significant_genes(pert) for the real and
    pred side per perturbation — each a full scan of the whole DE table, so
    ~18k × 2 ≈ 36k scans, where only the significant-gene count is used.
  2. compute_generic_auc (backing both pr_auc and roc_auc) built its
    merged frame once (good) but then did merged.filter(target == pert) per
    perturbation — ~18k more full scans, ×2 for the two AUC variants.

Fix

  • DENsigCounts: one filter_to_significant().group_by(target).len() per
    side, then reindex over the full iter_perturbations() universe filling 0
    for perts with no significant genes (matching the old empty .size). Only the
    count is consumed downstream, so this is exact.
  • compute_generic_auc: one merged.partition_by(target, maintain_order=True, as_dict=True) before the loop; iterate the dict. Perts absent from the map →
    nan (matching the old shape[0] == 0 branch). maintain_order keeps each
    partition in the exact row order the per-pert .filter produced, so the
    labels/scores arrays handed to average_precision_score / roc_curve are
    bit-identical — the per-pert sklearn calls are untouched, only the slicing
    changed from O(perts × rows) to O(rows) total. (partition_by(as_dict=True)
    keys are tuples on newer polars and scalars on older, so they're normalized.)

The audit covered the rest of metrics/_de.py: DESpearmanSignificant,
DESpearmanLFC, DEDirectionMatch, and DESigGenesRecall already use a single
group_by/join (no per-pert loop) and are unchanged.

Parity (no behavior change)

Pure performance; outputs bit-identical. New
tests/test_de_perpert_scan_equivalence.py reproduces the pre-optimization
DENsigCounts and compute_generic_auc verbatim and asserts the new code
matches across a synthetic multi-pert DEComparison — including the
all-significant / all-non-significant perturbations that map to nan, and the
zero-significant reindex path.

Benchmark

benchmarks/bench_de_metrics.py (Apple M2 Pro, polars 1.41; verbatim pre-fix
baselines vs new, all outputs verified bit-identical), n_genes=50 (~the
production DE table's rows-per-pert):

n_pert rows nsig old→new nsig speedup pr old→new pr speedup roc speedup
1000 50k 0.295→0.002s 130x 0.460→0.264s 1.7x 1.6x
2000 100k 0.626→0.016s 40x 0.856→0.506s 1.7x 1.7x
4000 200k 1.532→0.018s 83x 1.724→0.968s 1.8x 1.8x
8000 400k 3.944→0.023s 172x 3.596→1.985s 1.8x 1.8x

DENsigCounts — the metric the faulthandler actually sat in — was pure
table-scan, so removing it gives 40–172x and the win grows with scale (old is
O(perts × rows), new O(rows)). The pr/roc AUC metrics get a steady ~1.8x: the
per-perturbation sklearn calls and the shared merged build are irreducible,
so only the O(perts × rows) slicing was removed (the absolute time saved still
grows with scale — 0.2s → 1.6s from 1k → 8k perts).

Reproduce: python benchmarks/bench_de_metrics.py --n-pert 1000 2000 4000 8000

Not addressed here (noted for scale)

The anndata-pair metric clustering_agreement remains unprofiled at 18k
perts
— a separate potential cliff not yet reached. discrimination_score is
already vectorized (Fix 1). Out of scope for this PR.


Tests / checks

  • tests/test_discrimination_score.py: equivalence vs the original loop,
    covering the exclusion, no-exclusion (embedding), exotic-metric fallback,
    duplicate-gene safety net, and target-gene-dominated (degenerate cosine) paths.
  • tests/test_de_overlap_equivalence.py: bit-exact overlap/precision
    equivalence + memoization guards (Fix 2).
  • tests/test_de_perpert_scan_equivalence.py: bit-exact DENsigCounts /
    pr_auc / roc_auc equivalence vs verbatim pre-optimization references (Fix 3).
  • benchmarks/bench_discrimination_score.py, benchmarks/bench_de_overlap.py,
    and benchmarks/bench_de_metrics.py: self-documenting old-vs-new
    microbenchmarks (each keeps a verbatim pre-optimization baseline and asserts
    identical output).
uv run pytest -v          # 89 passed
uv run ruff format --check
uv run ruff check
uv run ty check

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request vectorizes the discrimination_score metric, replacing the slow per-perturbation loop with an optimized matrix computation that includes closed-form column corrections for L1, L2, and cosine distances. It also adds a microbenchmark script and comprehensive equivalence tests. The reviewer suggested optimizing the ranking step by replacing the double argsort with a vectorized comparison (np.where) to find the diagonal element's rank, reducing complexity from $O(N^2 \log N)$ to $O(N^2)$ and saving memory.

Comment thread src/cell_eval/metrics/_anndata.py Outdated
Comment on lines +196 to +201
# Rank of the matching perturbation within each row, by ascending distance.
# argsort(argsort(row)) is the inverse permutation, i.e. the rank of every
# column; the diagonal entry is the rank of the correct perturbation.
n_pert = data.perts.size
order = np.argsort(dist_matrix, axis=1)
ranks = np.argsort(order, axis=1)[np.arange(n_pert), np.arange(n_pert)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation uses a double argsort (np.argsort(np.argsort(dist_matrix, axis=1))) to find the rank of the diagonal element (the matching perturbation) in each row. Sorting the entire order matrix a second time is computationally expensive ($O(N^2 \log N)$) and requires significant memory ($O(N^2)$ integers).

Since we only need the rank of the diagonal element $i$ in each row $i$, we can find its column index in the order matrix using a vectorized comparison: np.where(order == np.arange(n_pert)[:, None])[1]. This avoids the second sorting step entirely, reducing the ranking overhead from $O(N^2 \log N)$ to $O(N^2)$ simple comparisons, which is about 8x faster and uses up to 8x less memory for the ranking step, while maintaining 100% identical tie-breaking behavior.

Suggested change
# Rank of the matching perturbation within each row, by ascending distance.
# argsort(argsort(row)) is the inverse permutation, i.e. the rank of every
# column; the diagonal entry is the rank of the correct perturbation.
n_pert = data.perts.size
order = np.argsort(dist_matrix, axis=1)
ranks = np.argsort(order, axis=1)[np.arange(n_pert), np.arange(n_pert)]
# Rank of the matching perturbation within each row, by ascending distance.
# Instead of a second argsort, we can find the column index of the diagonal
# element in the sorted order matrix using a vectorized comparison.
n_pert = data.perts.size
order = np.argsort(dist_matrix, axis=1)
ranks = np.where(order == np.arange(n_pert)[:, None])[1]

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — adopted. The diagonal of argsort(argsort(.)) is just the position of column i within row i, which the boolean match finds directly, so the result is identical; I verified equality across random seeds and a heavy-ties matrix.

I benchmarked the claim on an Apple M2 Pro at n_pert=10000:

  • ranking step: 9.8s → 5.7s (~1.7x), peak RSS 2.42 GB → 1.74 GB.
  • The ~8x is accurate for the second argsort intermediate specifically (an 800 MB int64 array replaced by a ~100 MB bool mask); end-to-end the ranking step is ~1.7x because the first argsort is shared by both approaches.

Since the ranking runs once per metric, dropping the second sort helped the BLAS-bound metrics most end-to-end: l2@10k 51x → 70x, cosine@10k 61x → 79x. Benchmark table in the PR description updated.

discrimination_score previously looped over n_pert perturbations, calling
pairwise_distances once per perturbation to compute a single row of an
n_pert x n_pert distance matrix. This replaces the loop with a single
full-matrix computation, then ranks each perturbation by locating its column's
position in the per-row sorted order.

The target-gene-exclusion path (the default for expression data) drops a
different feature column per perturbation, so a single unmasked pairwise call
cannot reproduce it. The full matrix is computed once and corrected per row
with an exact, vectorized rank-1 update that removes the target gene's
contribution (l1: subtract |delta|; l2: sqrt(d^2 - delta^2); cosine: drop the
column from the dot product and both norms). Metrics without a closed-form
column correction fall back to exact per-row masked distances, and duplicate
gene names matching one perturbation are handled by an exact per-row net.

Ranking uses a boolean match (np.where(order == arange)) to find each row's
matching column, rather than a second argsort over the full matrix: identical
result, but it avoids an extra O(n_pert^2) int64 array and a second sort. At
n_pert=10000 the ranking step alone drops from ~9.8s / 2.4 GB to ~5.7s /
1.7 GB, which is a large share of the l2/cosine runtime.

Output is numerically identical: across 216 synthetic configurations
(metric x exclude x embed_key x seed x targeting-fraction) the normalized
ranks match the original loop bit-for-bit (worst |delta-rank| = 0).

Measured speedups (Apple M2 Pro, Python 3.12, numpy 2.4 / scipy 1.17 /
scikit-learn 1.8; n_genes=2000; ranks identical to the loop at every point):

  n_pert   l1     l2      cosine
  100      4.1x   9.4x    15.6x
  1000     3.6x   20.7x   39.3x
  2000     5.3x   40.7x   61.0x
  10000    8.6x   70.5x   79.2x

l2/cosine use the BLAS dot-product trick, so a single matrix multiply replaces
n_pert dispatched calls and the advantage grows with n_pert; l1 (manhattan) is
a non-BLAS kernel, so its gain comes from removing per-call dispatch overhead.
Memory is O(n_pert^2) for the full matrix, vs O(n_pert) per iteration before;
this is the cost of vectorization and is comfortable for typical screen sizes.

The cosine column correction clips masked squared norms at zero before the
square root: an effect dominated by its target gene can round the masked norm
slightly negative, which would otherwise yield NaN distances.

Adds tests/test_discrimination_score.py (equivalence vs the original loop,
covering the exclusion, exotic-metric, duplicate-gene, and target-gene-
dominated paths) and benchmarks/bench_discrimination_score.py.
@FarzanT FarzanT force-pushed the perf/vectorize-discrimination-score branch from 4bd7aff to b1825e0 Compare May 27, 2026 16:42
FarzanT added 2 commits May 28, 2026 08:09
compute_overlap rebuilt the per-side rank matrix on every call: it invokes
get_top_genes(sort_by, fdr_threshold) for the real and pred sides, and that
builds a polars .pivot() with one column per perturbation. The de/full profile
registers 10 overlap variants (overlap/precision x k in {None,50,100,200,500}),
all with the same default sort_by and fdr_threshold, so the identical wide pivot
was rebuilt 10 times per side -- k only truncates the per-pert gene list
downstream, never the matrix itself.

Memoize get_top_genes on the DEResults instance keyed by (sort_by,
fdr_threshold) so the pivot is built once per side and reused across all
variants. The cache is a dataclass field excluded from init/repr/eq.

Also hoist the rank-matrix column names into sets once before the per-pert loop.
polars rebuilds a fresh column list on every .columns access, so the two
"pert not in matrix.columns" membership tests were O(n_perts) each, making the
loop O(n_perts^2); at ~18k perturbations this dominated the metric pass.

Both changes are pure performance: outputs are bit-identical. Adds
tests/test_de_overlap_equivalence.py asserting compute_overlap matches a
from-scratch reference across k and metric, that the cache collapses the 10
variants to one pivot per side, and that distinct (sort_by, fdr_threshold) keys
stay separate.
Mirrors bench_discrimination_score.py: keeps a verbatim copy of the
pre-memoization get_top_genes / compute_overlap as the baseline, runs the full
10-variant overlap/precision pattern across a sweep of perturbation counts, and
asserts old and new produce identical results. Measures the redundant-pivot +
O(perts^2)-membership removal: speedup grows from ~2.2x at 1k perts to ~8.3x at
8k (n_sig=100), tracking the ~quadratic old path vs ~linear new path.
@FarzanT FarzanT changed the title Vectorize discrimination_score distance computation Vectorize discrimination_score; memoize DE overlap rank-matrix pivot May 28, 2026
DENsigCounts looped over every perturbation calling get_significant_genes(pert)
for the real and pred sides, and compute_generic_auc (pr/roc) looped calling
merged.filter(target == pert). Each is a full-table scan per perturbation --
O(n_pert * n_rows) -- so at ~18k perturbations over a 371k-row DE table they
dominate the metric pass (the sampler sat in get_significant_genes ->
filter().collect() and in the per-pert filter for the whole sampling window).

Replace both with a single slice:
- DENsigCounts: one filter_to_significant().group_by(target).len() per side,
  reindexed over the full perturbation universe (0 for perts with no
  significant genes), since only the count is used.
- compute_generic_auc: one merged.partition_by(target, maintain_order=True)
  before the loop; perts absent from the partition map -> nan, matching the old
  empty-slice branch. maintain_order keeps each partition in the row order the
  per-pert filter produced, so the labels/scores handed to
  average_precision_score / roc_curve are bit-identical. partition_by(as_dict)
  keys are tuples on newer polars and scalars on older, so normalize to str.

Pure performance; output is bit-identical. The sibling DE metrics
(DESpearmanSignificant/LFC, DEDirectionMatch, DESigGenesRecall) already use the
group_by/join shape and are unchanged. Adds
tests/test_de_perpert_scan_equivalence.py with verbatim pre-optimization
references for both metrics.
@FarzanT FarzanT changed the title Vectorize discrimination_score; memoize DE overlap rank-matrix pivot Vectorize discrimination_score; scale DE overlap/counts/AUC metrics to high perturbation counts May 28, 2026
Mirrors bench_de_overlap.py: verbatim pre-optimization DENsigCounts and
compute_generic_auc baselines vs the current single-slice versions, swept over
perturbation counts with outputs asserted identical. DENsigCounts (the metric
the faulthandler sat in) drops 40-175x and grows with scale; the pr/roc AUC
metrics get a steady ~1.8x because the per-pert sklearn calls and the shared
merged build are irreducible -- only the O(perts x rows) table slicing was
removed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

discrimination_score recomputes an n_pert x n_pert distance matrix one row at a time

1 participant